{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FPGA实践自选项目\n",
    "# 卷积神经网络实现MNIST手写数字识别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/javascript": [
       "\n",
       "require(['notebook/js/codecell'], function(codecell) {\n",
       "  codecell.CodeCell.options_default.highlight_modes[\n",
       "      'magic_text/x-csrc'] = {'reg':[/^%%microblaze/]};\n",
       "  Jupyter.notebook.events.one('kernel_ready.Kernel', function(){\n",
       "      Jupyter.notebook.get_cells().map(function(cell){\n",
       "          if (cell.cell_type == 'code'){ cell.auto_highlight(); } }) ;\n",
       "  });\n",
       "});\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from pynq import MMIO\n",
    "from pynq import Overlay\n",
    "from pynq import Xlnk\n",
    "\n",
    "import util"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the overlay from the bitstream file that sits next to this notebook.\n",
    "bitstream_path = \"./design_1.bit\"\n",
    "overlay = Overlay(bitstream_path)\n",
    "\n",
    "# Get a handle on the contiguous-memory allocator and reset it before any\n",
    "# buffers are allocated below.\n",
    "xlnk = Xlnk()\n",
    "xlnk.xlnk_reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 此处分配内存空间为旧版代码，可能需要更改\n",
    "### 适用于PYNQv2.4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Contiguous buffers shared with the CNN IP: one 784-element (28x28) uint8\n",
    "# plane per colour channel, allocated in b, g, r order.\n",
    "img_in_b, img_in_g, img_in_r = (\n",
    "    xlnk.cma_array(shape=(784,), dtype=np.uint8) for _ in range(3)\n",
    ")\n",
    "# One int32 slot from which the prediction is read back after a run.\n",
    "img_out = xlnk.cma_array(shape=(1,), dtype=np.int32)\n",
    "\n",
    "# Physical addresses; these are the values written to the IP's pointer registers.\n",
    "in_buffer_address_b = img_in_b.physical_address\n",
    "in_buffer_address_g = img_in_g.physical_address\n",
    "in_buffer_address_r = img_in_r.physical_address\n",
    "out_buffer_address = img_out.physical_address\n",
    "\n",
    "# Uncomment to inspect the addresses:\n",
    "# print(\"in_buffer_address_b:\",in_buffer_address_b)\n",
    "# print(\"in_buffer_address_g:\",in_buffer_address_g)\n",
    "# print(\"in_buffer_address_r:\",in_buffer_address_r)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 根据BD的分配地址填写"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Memory-mapped register window of the CNN IP. These values must match the\n",
    "# address assigned to the IP in the block design.\n",
    "IP_BASE_ADDRESS = 0x4000_0000\n",
    "ADDRESS_RANGE = 0x80\n",
    "\n",
    "# Control / interrupt register offsets inside the window.\n",
    "# AP_CTRL: 1 is written to start a run; bit 1 is polled as 'done' and\n",
    "# bit 2 as 'idle' by CNN_Init_EX.\n",
    "FPGA_img_addr_AP_CTRL = 0x00\n",
    "FPGA_img_addr_GIE = 0x04  # presumably global interrupt enable; written with 0 before start\n",
    "FPGA_img_addr_IER = 0x08\n",
    "FPGA_img_addr_ISR = 0x0C\n",
    "\n",
    "# Offsets of the registers that receive the buffers' physical addresses.\n",
    "FPGA_img_addr_in_b = 0x10\n",
    "FPGA_img_addr_in_g = 0x20\n",
    "FPGA_img_addr_in_r = 0x30\n",
    "FPGA_img_addr_out = 0x40"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CNN IP DRIVER\n",
    "### 分配地址，开始计算"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _wait_for_ctrl_bit(mmio, bit, timeout_s):\n",
    "    \"\"\"Spin until bit `bit` of the AP_CTRL register reads 1.\n",
    "\n",
    "    Raises TimeoutError after `timeout_s` seconds; `timeout_s=None` waits forever.\n",
    "    \"\"\"\n",
    "    deadline = None if timeout_s is None else time.monotonic() + timeout_s\n",
    "    while not ((mmio.read(FPGA_img_addr_AP_CTRL) >> bit) & 0x01):\n",
    "        if deadline is not None and time.monotonic() > deadline:\n",
    "            raise TimeoutError(\n",
    "                'CNN IP: AP_CTRL bit %d still 0 after %.1f s' % (bit, timeout_s))\n",
    "\n",
    "\n",
    "def CNN_Init_EX(mmio=None, timeout_s=5.0):\n",
    "    \"\"\"Run the CNN IP once on the current input buffers and wait for completion.\n",
    "\n",
    "    Waits for the IP to report idle, writes the physical addresses of the three\n",
    "    input buffers and of the output buffer into the IP's pointer registers,\n",
    "    writes 0 to GIE, writes 1 to AP_CTRL to start the run, then polls the done\n",
    "    bit. Callers read the prediction from `img_out` afterwards.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    mmio : MMIO, optional\n",
    "        An existing mapping of the IP's register window. Pass one to avoid\n",
    "        creating a new mapping on every call; when omitted a fresh\n",
    "        MMIO(IP_BASE_ADDRESS, ADDRESS_RANGE) is created, as before.\n",
    "    timeout_s : float or None, optional\n",
    "        Maximum time to wait for each of the idle and done conditions.\n",
    "        None restores the original wait-forever behaviour.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    TimeoutError\n",
    "        If the IP does not become idle, or does not finish, within `timeout_s`\n",
    "        (e.g. wrong bitstream or base address) instead of hanging the kernel.\n",
    "    \"\"\"\n",
    "    if mmio is None:\n",
    "        mmio = MMIO(IP_BASE_ADDRESS, ADDRESS_RANGE)\n",
    "\n",
    "    # Bit 2 of AP_CTRL: the IP is idle and can accept a new run.\n",
    "    _wait_for_ctrl_bit(mmio, 2, timeout_s)\n",
    "\n",
    "    # Tell the IP where its input planes and its output slot live.\n",
    "    mmio.write(FPGA_img_addr_in_b, in_buffer_address_b)\n",
    "    mmio.write(FPGA_img_addr_in_g, in_buffer_address_g)\n",
    "    mmio.write(FPGA_img_addr_in_r, in_buffer_address_r)\n",
    "    mmio.write(FPGA_img_addr_out, out_buffer_address)\n",
    "\n",
    "    # Completion is detected by polling, so GIE is cleared before starting.\n",
    "    mmio.write(FPGA_img_addr_GIE, 0)\n",
    "    mmio.write(FPGA_img_addr_AP_CTRL, 1)\n",
    "\n",
    "    # Bit 1 of AP_CTRL: the run has finished.\n",
    "    _wait_for_ctrl_bit(mmio, 1, timeout_s)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取MNIST数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Later cells index each row as 784 pixel values followed by the label in\n",
    "# column 784.\n",
    "test_data = util.load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(784,)\n",
      "7\n"
     ]
    }
   ],
   "source": [
    "index = 0\n",
    "input_data = test_data[0,:-1]\n",
    "label      = test_data[0,784]\n",
    "print(input_data.shape)\n",
    "print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAADOxJREFUeJzt3V2IXfW5x/Hf76QpiOlFYjUMNpqeogeiYHIcRTCWqLXEWIjFIPWipFAyvUjkFEo44rloLov0hXphYUpDY8mxFdJqFLGx8WBOUIujqHkzMQmpmZi3MkITQdro04tZttM4+7939tva4/P9wDB7r2e9PGzmN2utvfbaf0eEAOTzb3U3AKAehB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFKf6efGbPNxQqDHIsKtzNfRnt/2ctv7bR+0/UAn6wLQX273s/22Z0k6IOkOSeOSXpF0X0TsLSzDnh/osX7s+W+UdDAiDkfEXyX9WtLKDtYHoI86Cf/lko5OeT5eTfsXtkdsj9ke62BbALqs52/4RcSopFGJw35gkHSy5z8macGU51+opgGYAToJ/yuSrrL9RduflfQNSVu70xaAXmv7sD8iztleJ+n3kmZJ2hgRe7rWGYCeavtSX1sb45wf6Lm+fMgHwMxF+IGkCD+QFOEHkiL8QFKEH0iK8ANJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBShB9IivADSRF+ICnCDyRF+IGkCD+QFOEHkiL8QFJtD9EtSbaPSDoj6UNJ5yJiuBtNAei9jsJfuTUi/tyF9QDoIw77gaQ6DX9I2mb7Vdsj3WgIQH90eti/NCKO2b5M0nO234qIHVNnqP4p8I8BGDCOiO6syN4g6WxE/LAwT3c2BqChiHAr87V92G/7Ytuf+/ixpK9K2t3u+gD0VyeH/fMl/c72x+v534h4titdAei5rh32t7QxDvuBnuv5YT+AmY3wA0kRfiApwg8kRfiBpAg/kFQ37upLYdWqVQ1ra9asKS777rvvFusffPBBsb558+Zi/cSJEw1rBw8eLC6LvNjzA0kRfiApwg8kRfiBpAg/kBThB5Ii/EBS3NLbosOHDzesLVy4sH+NTOPMmTMNa3v27OljJ4NlfHy8Ye2hhx4qLjs2NtbtdvqGW3oBFBF+ICnCDyRF+IGkCD+QFOEHkiL8QFLcz9+i0j371113XXHZvXv3FuuLFi0q1pcsWVKsL1u2rGHtpptuKi579OjRYn3BggXFeifOnTtXrJ8+fbpYHxoaanvb77zzTrE+k6/zt4o9P5AU4QeSIvxAUoQfSIrwA0kRfiApwg8k1fR+ftsbJX1N0qmIuLaaNk/SbyQtlHRE0r0R8V7Tjc3g+/kH2dy5cxvWmn1GoNn17BtuuKGtnlrRbLyCAwcOFOv79u0r1ufNm9ewtm7duuKyjzzySLE+yLp5P/8vJS0/b9oDkrZHxFWStlfPAcwgTcMfETskTZw3eaWkTdXjTZLu7nJfAHqs3XP++RFxvHp8QtL8LvUDoE86/mx/RETpXN72iKSRTrcDoLva3fOftD0kSdXvU41mjIjRiBiOiOE2twWgB9oN/1ZJq6vHqyU92Z12APRL0/DbfkzSS5L+w/a47W9L+oGkO2y/Lekr1XMAMwjf24+Bdc899xTrjz/+eLG+e/fuhrVbb721uOzExPkXuGYOvrcfQBHhB5Ii/EBShB9IivADSRF+ICku9aE2l112WbG+a9eujpZftWpVw9qWLVuKy85kXOoDUET4gaQIP5AU4QeSIvxAUoQfSIrwA0kxRDdqs3bt2mL90ksvLdbfe6/8bfH79++/4J4yYc8PJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0lxPz966uabb25Ye/7554vLzp49u1hftmxZsb5jx45i/dOK+/kBFBF+ICnC
DyRF+IGkCD+QFOEHkiL8QFJN7+e3vVHS1ySdiohrq2kbJK2RdLqa7cGIeKZXTWLmWrFiRcNas+v427dvL9ZfeumltnrCpFb2/L+UtHya6T+JiMXVD8EHZpim4Y+IHZIm+tALgD7q5Jx/ne03bW+0PbdrHQHoi3bD/zNJX5K0WNJxST9qNKPtEdtjtsfa3BaAHmgr/BFxMiI+jIiPJP1c0o2FeUcjYjgihtttEkD3tRV+20NTnn5d0u7utAOgX1q51PeYpGWSPm97XNL3JS2zvVhSSDoi6Ts97BFAD3A/Pzpy0UUXFes7d+5sWLvmmmuKy952223F+osvvlisZ8X9/ACKCD+QFOEHkiL8QFKEH0iK8ANJMUQ3OrJ+/fpifcmSJQ1rzz77bHFZLuX1Fnt+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iKW3pRdNdddxXrTzzxRLH+/vvvN6zdeeedxWX5au72cEsvgCLCDyRF+IGkCD+QFOEHkiL8QFKEH0iK+/mTu+SSS4r1hx9+uFifNWtWsf7MM40HcOY6fr3Y8wNJEX4gKcIPJEX4gaQIP5AU4QeSIvxAUk3v57e9QNKjkuZLCkmjEfFT2/Mk/UbSQklHJN0bEe81WRf38/dZs+vwL7/8crF+/fXXF+uHDh0q1pcvX972smhPN+/nPyfpexGxSNJNktbaXiTpAUnbI+IqSdur5wBmiKbhj4jjEfFa9fiMpH2SLpe0UtKmarZNku7uVZMAuu+CzvltL5S0RNIfJc2PiONV6YQmTwsAzBAtf7bf9hxJWyR9NyL+Yv/ztCIiotH5vO0RSSOdNgqgu1ra89uercngb46I31aTT9oequpDkk5Nt2xEjEbEcEQMd6NhAN3RNPye3MX/QtK+iPjxlNJWSaurx6slPdn99gD0SiuX+pZK+n9JuyR9VE1+UJPn/Y9LukLSnzR5qW+iybq41NdnV199dbH+1ltvdbT+lStXFutPPfVUR+vHhWv1Ul/Tc/6I2Cmp0cpuv5CmAAwOPuEHJEX4gaQIP5AU4QeSIvxAUoQfSIqv7v4UuPLKKxvWtm3b1tG6169fX6w//fTTHa0f9WHPDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJcZ3/U2BkpPG3pF1xxRUdrfuFF14o1pt9HwQGF3t+ICnCDyRF+IGkCD+QFOEHkiL8QFKEH0iK6/wzwC233FKs33///X3qBJ8m7PmBpAg/kBThB5Ii/EBShB9IivADSRF+IKmm1/ltL5D0qKT5kkLSaET81PYGSWskna5mfTAinulVo5ktXbq0WJ8zZ07b6z506FCxfvbs2bbXjcHWyod8zkn6XkS8Zvtzkl61/VxV+0lE/LB37QHolabhj4jjko5Xj8/Y3ifp8l43BqC3Luic3/ZCSUsk/bGatM72m7Y32p7bYJkR22O2xzrqFEBXtRx+23MkbZH03Yj4i6SfSfqSpMWaPDL40XTLRcRoRAxHxHAX+gXQJS2F3/ZsTQZ/c0T8VpIi4mREfBgRH0n6uaQbe9cmgG5rGn7blvQLSfsi4sdTpg9Nme3rknZ3vz0AvdLKu/03S/qmpF22X6+mPSjpPtuLNXn574ik7/SkQ3TkjTfeKNZvv/32Yn1iYqKb7WCAtPJu/05JnqbENX1gBuMTfkBShB9IivADSRF+ICnCDyRF+IGk3M8hlm0znjPQYxEx3aX5T2DPDyRF+IGkCD+QFOEHkiL8QFKEH0iK8ANJ9XuI7j9L+tOU55+vpg2iQe1tUPuS6K1d3eztylZn7OuHfD6xcXtsUL/bb1B7G9S+JHprV129cdgPJEX4gaTqDv9ozdsvGdTeBrUvid7aVUtvtZ7zA6hP3Xt+ADWpJfy2l9veb/ug7Qfq6KER20ds77L9et1DjFXDoJ2yvXvKtHm2n7P9dvV72mHSauptg+1j1Wv3uu0VNfW2wPb/2d5re4/t/6qm1/raFfqq5XXr+2G/7VmSDki6Q9K4pFck3RcR
e/vaSAO2j0gajojarwnb/rKks5IejYhrq2kPSZqIiB9U/zjnRsR/D0hvGySdrXvk5mpAmaGpI0tLulvSt1Tja1fo617V8LrVsee/UdLBiDgcEX+V9GtJK2voY+BFxA5J54+asVLSpurxJk3+8fRdg94GQkQcj4jXqsdnJH08snStr12hr1rUEf7LJR2d8nxcgzXkd0jaZvtV2yN1NzON+dWw6ZJ0QtL8OpuZRtORm/vpvJGlB+a1a2fE627jDb9PWhoR/ynpTklrq8PbgRST52yDdLmmpZGb+2WakaX/oc7Xrt0Rr7utjvAfk7RgyvMvVNMGQkQcq36fkvQ7Dd7owyc/HiS1+n2q5n7+YZBGbp5uZGkNwGs3SCNe1xH+VyRdZfuLtj8r6RuSttbQxyfYvrh6I0a2L5b0VQ3e6MNbJa2uHq+W9GSNvfyLQRm5udHI0qr5tRu4Ea8jou8/klZo8h3/Q5L+p44eGvT175LeqH721N2bpMc0eRj4N02+N/JtSZdI2i7pbUl/kDRvgHr7laRdkt7UZNCGauptqSYP6d+U9Hr1s6Lu167QVy2vG5/wA5LiDT8gKcIPJEX4gaQIP5AU4QeSIvxAUoQfSIrwA0n9HfeMP6R+37kqAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0xaf4c4bf0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "image = input_data\n",
    "pixels = image.reshape(28,28)\n",
    "pixels = np.stack([pixels,pixels,pixels],axis = 0)\n",
    "pixels = np.transpose(pixels,(1,2,0))\n",
    "plt.imshow(pixels/255, cmap=plt.cm.gray)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 调用IP 送入数据进行推理\n",
    "\n",
    "### 输出推理时间和准确率\n",
    "\n",
    "### num为推理图片数目，上限10000张"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.283 ms/img\n",
      "0.89\n"
     ]
    }
   ],
   "source": [
    "# Run `num` test images through the IP, then report the mean time per image\n",
    "# (host-side copy and polling included) and the accuracy.\n",
    "cnt = 0\n",
    "num = 100\n",
    "num = min(num, len(test_data))  # never index past the end of the dataset\n",
    "\n",
    "# NOTE(review): only img_in_r is refreshed per image; img_in_b / img_in_g are\n",
    "# never written in this notebook although their addresses are handed to the\n",
    "# IP - confirm the IP ignores them.\n",
    "start = time.perf_counter()\n",
    "for i in range(num):\n",
    "    label = test_data[i, -1]\n",
    "    # Convert explicitly so the copy into the uint8 buffer does not depend on\n",
    "    # the dtype of test_data.\n",
    "    image_array_r = np.array(test_data[i, :-1], dtype='uint8')\n",
    "    np.copyto(img_in_r, image_array_r)\n",
    "    CNN_Init_EX()\n",
    "    pred_out = img_out.copy()  # snapshot of the value the IP wrote\n",
    "    if pred_out[0] == label:\n",
    "        cnt += 1\n",
    "end = time.perf_counter()\n",
    "\n",
    "print(round((end - start) / num * 1000, 3), 'ms/img')\n",
    "print(cnt / num)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
